{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x10d5dea30>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import math\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dataset:\n",
    "    def __init__(self, vector_size=6, seed=None):\n",
    "        self.rng = np.random.RandomState(seed)\n",
    "        self.vector_size = vector_size\n",
    "        self.a_start = 0\n",
    "        self.a_end = 2\n",
    "        self.b_start = 4\n",
    "        self.b_end = 6\n",
    "        self.bias = self.rng.uniform(1, 11, size=(1, self.vector_size))\n",
    "    \n",
    "    def batch(self, batch_size=128):\n",
    "        v = self.rng.uniform(0, 0.1, size=(batch_size, self.vector_size)) + self.bias\n",
    "        a = np.sum(v[:, self.a_start:self.a_end], axis=1)\n",
    "        b = np.sum(v[:, self.b_start:self.b_end], axis=1)\n",
    "        t = a + b\n",
    "        \n",
    "        return (torch.tensor(v, dtype=torch.float32), torch.tensor(t[:, np.newaxis], dtype=torch.float32))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GumbelNACLayer(torch.nn.Module):\n",
    "    \"\"\"Implements the Gumbel NAC (Neural Accumulator)\n",
    "\n",
    "    Arguments:\n",
    "        in_features: number of ingoing features\n",
    "        out_features: number of outgoing features\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, out_features):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "\n",
    "        self.tau = torch.nn.Parameter(torch.Tensor(1), requires_grad=False)\n",
    "        self.register_buffer('target_weights', torch.Tensor([1, -1, 0]))\n",
    "\n",
    "        self.W_hat = torch.nn.Parameter(torch.Tensor(out_features, in_features, 2))\n",
    "        self.register_buffer('W_hat_k', torch.Tensor(out_features, in_features, 1))\n",
    "        self.register_parameter('bias', None)\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        # Initialize to zero, the source of randomness can come from the Gumbel sampling.\n",
    "        torch.nn.init.constant_(self.W_hat, 0)\n",
    "        torch.nn.init.constant_(self.W_hat_k, 0)\n",
    "        torch.nn.init.constant_(self.tau, 1)\n",
    "    \n",
    "    def set_tau(self, tau):\n",
    "        self.tau.fill_(tau)\n",
    "\n",
    "    def forward(self, input):\n",
    "        # Concat W_hat with a constant (W_hat_k), such that only two parameters controls\n",
    "        # 3 classes in the softmax.\n",
    "        W_hat_full = torch.cat((self.W_hat, self.W_hat_k), dim=-1)  # size = [out, in, 3]\n",
    "        \n",
    "        # Sample from gumbel-softmax depennding on W_hat_full, which have been\n",
    "        # turned into log properbilities.\n",
    "        log_pi = torch.nn.functional.log_softmax(W_hat_full, dim=-1)\n",
    "        y = torch.nn.functional.gumbel_softmax(log_pi.view(-1, 3), tau=self.tau).view(log_pi.size())\n",
    "        W = y @ self.target_weights\n",
    "        \n",
    "        return torch.nn.functional.linear(input, W, self.bias)\n",
    "\n",
    "    def extra_repr(self):\n",
    "        return 'in_features={}, out_features={}'.format(\n",
    "            self.in_features, self.out_features\n",
    "        )\n",
    "\n",
    "class Network(torch.nn.Module):\n",
    "    def __init__(self, model_name, vector_size=6):\n",
    "        super().__init__()\n",
    "        self.model_name = model_name\n",
    "\n",
    "        if model_name == 'GumbelNAC':\n",
    "            self.layer_1 = GumbelNACLayer(vector_size, 2)\n",
    "            self.layer_2 = GumbelNACLayer(2, 1)\n",
    "        elif model_name == 'linear':\n",
    "            self.layer_1 = torch.nn.Linear(vector_size, 2)\n",
    "            self.layer_2 = torch.nn.Linear(2, 1)\n",
    "        else:\n",
    "            raise NotImplemented(f'{model_name} is not implemented')\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        self.layer_1.reset_parameters()\n",
    "        self.layer_2.reset_parameters()\n",
    "    \n",
    "    def set_tau(self, tau):\n",
    "        if self.model_name == 'GumbelNAC':\n",
    "            self.layer_1.set_tau(tau)\n",
    "            self.layer_2.set_tau(tau)\n",
    "\n",
    "    def forward(self, input):\n",
    "        z_1 = self.layer_1(input)\n",
    "        z_2 = self.layer_2(z_1)\n",
    "        return z_2\n",
    "\n",
    "    def extra_repr(self):\n",
    "        return 'vector_size={}'.format(\n",
    "            self.vector_size\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train 0: 462.5357360839844\n",
      "train 1000: 717.829345703125\n",
      "train 2000: 647.0888061523438\n",
      "train 3000: 164.76223754882812\n",
      "train 4000: 99.82171630859375\n",
      "train 5000: 10.309021949768066\n",
      "train 6000: 124.35419464111328\n",
      "train 7000: 9.074092864990234\n",
      "train 8000: 1.4919761419296265\n"
     ]
    }
   ],
   "source": [
    "# Settings of the training run.\n",
    "NUM_STEPS = 100000\n",
    "LOG_INTERVAL = 1000\n",
    "TAU_DECAY = 1e-5\n",
    "TAU_FLOOR = 0.5\n",
    "\n",
    "dataset = Dataset(vector_size=6, seed=0)\n",
    "model = Network('GumbelNAC', vector_size=6)\n",
    "model.reset_parameters()\n",
    "\n",
    "criterion = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "# A fresh batch is drawn on every iteration, so epoch_i counts optimizer steps.\n",
    "for epoch_i in range(NUM_STEPS):\n",
    "    # Exponentially decay the Gumbel-softmax temperature, never below the floor.\n",
    "    tau = max(TAU_FLOOR, math.exp(-TAU_DECAY * epoch_i))\n",
    "    model.set_tau(tau)\n",
    "\n",
    "    # Prepare\n",
    "    x, t = dataset.batch()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # Loss\n",
    "    y = model(x)\n",
    "    loss = criterion(y, t)\n",
    "\n",
    "    # Optimize\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if epoch_i % LOG_INTERVAL == 0:\n",
    "        print(f'train {epoch_i}: {loss}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record the interpreter version used for this run (sys is imported in the first cell).\n",
    "print(sys.version)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
